# ======================================================================
# Copyright CERFACS (June 2019)
# Contributor: Adrien Suau (adrien.suau@cerfacs.fr)
#
# This software is governed by the CeCILL-B license under French law and
# abiding  by the  rules of  distribution of free software. You can use,
# modify  and/or  redistribute  the  software  under  the  terms  of the
# CeCILL-B license as circulated by CEA, CNRS and INRIA at the following
# URL "http://www.cecill.info".
#
# As a counterpart to the access to  the source code and rights to copy,
# modify and  redistribute granted  by the  license, users  are provided
# only with a limited warranty and  the software's author, the holder of
# the economic rights,  and the  successive licensors  have only limited
# liability.
#
# In this respect, the user's attention is drawn to the risks associated
# with loading,  using, modifying and/or  developing or reproducing  the
# software by the user in light of its specific status of free software,
# that  may mean  that it  is complicated  to manipulate,  and that also
# therefore  means that  it is reserved for  developers and  experienced
# professionals having in-depth  computer knowledge. Users are therefore
# encouraged  to load and  test  the software's  suitability as  regards
# their  requirements  in  conditions  enabling  the  security  of their
# systems  and/or  data to be  ensured and,  more generally,  to use and
# operate it in the same conditions as regards security.
#
# The fact that you  are presently reading this  means that you have had
# knowledge of the CeCILL-B license and that you accept its terms.
# ======================================================================

from qat.lang.AQASM.program import Program
from qat.lang.AQASM.routines import QRoutine


def get_circuit(routine: QRoutine, **to_circ_kwargs):
    """Compile a routine into a circuit through a fresh, dedicated program.

    :param routine: the routine to compile. A register of ``routine.arity``
        qubits is allocated and the routine is applied on it.
    :param to_circ_kwargs: keyword arguments forwarded unchanged to
        ``Program.to_circ``.
    :return: whatever ``Program.to_circ`` returns for that program.
    """
    program = Program()
    register = program.qalloc(routine.arity)
    program.apply(routine, register)
    return program.to_circ(**to_circ_kwargs)


def merge(*circs):
    """Merge the circuits given in parameters into one big circuit.

    :param circs: circuits to merge. All the given circuits must have the same
        arity (``nbqbits``) and are concatenated, with ``+``, in the order they
        are given, starting from an empty circuit of that common arity.
    :return: the circuit obtained after concatenating every circuit of ``circs``.
    :raises ValueError: if no circuit is given (the arity would be unknown).
    :raises AssertionError: if the circuits do not all have the same arity.
    """
    if not circs:
        # Without at least one circuit there is no arity to allocate.
        raise ValueError("merge requires at least one circuit.")

    arity = circs[0].nbqbits
    if any(c.nbqbits != arity for c in circs):
        # Explicit raise instead of ``assert`` so that the check is not stripped
        # under ``python -O``; AssertionError is kept for backward compatibility.
        raise AssertionError("Incompatible circuits in merge: not the same arity.")

    # Start from an empty circuit with the right number of qubits.
    prog = Program()
    prog.qalloc(arity)
    circ = prog.to_circ()

    for c in circs:
        circ = circ + c

    return circ


def repeat(circ_to_repeat, repetitions: int):
    """Repeat the same circuit multiple times.

    :param circ_to_repeat: the circuit that should be repeated.
    :param repetitions: the number of times the circuit should be repeated.
        ``0`` yields an empty circuit with as many qubits as ``circ_to_repeat``.
    :raises ValueError: if ``repetitions`` is negative.
    """
    if repetitions < 0:
        raise ValueError(
            "repeat expects a non-negative number of repetitions, got {}.".format(
                repetitions
            )
        )
    if repetitions == 0:
        # merge() cannot be called without circuits (the arity would be
        # unknown), so build the empty circuit of the right arity directly.
        prog = Program()
        prog.qalloc(circ_to_repeat.nbqbits)
        return prog.to_circ()
    return merge(*([circ_to_repeat] * repetitions))


def _get_gate_name(gate, gate_dict=None) -> str:
    """Return the name of the given gate."""
    if gate.is_ctrl:
        if gate_dict:
            return "CTRL({})".format(_get_gate_name(gate_dict[gate.subgate], gate_dict))
        else:
            return "CTRL({})".format(_get_gate_name(gate.subgate))
    elif hasattr(gate, "syntax") and gate.syntax is not None:
        return gate.syntax.name
    elif hasattr(gate, "name") and gate.name is not None:
        return gate.name
    else:
        raise RuntimeError("Unsupported operation in get_gate_name: {}".format(gate))


def compute_gate_count(circuit):
    """Return a dictionary mapping each gate name to its number of occurrences.

    :param circuit: a circuit exposing ``gateDic`` (iterated for gate
        identifiers, indexed to get their definitions) and ``ops`` (operations
        whose ``gate`` attribute is such an identifier).
    :return: a plain dict ``{gate name: count}``; names are inserted in order of
        first appearance in ``circuit.ops``.
    """
    gate_dic = circuit.gateDic
    # Name every known gate identifier once; the whole dictionary is passed so
    # that controlled gates can resolve their subgates.
    names_by_id = {
        gate_id: _get_gate_name(gate_dic[gate_id], gate_dic) for gate_id in gate_dic
    }

    counts = dict()
    for name in (names_by_id[operation.gate] for operation in circuit.ops):
        counts[name] = counts.get(name, 0) + 1
    return counts


def get_gate_names(circuit):
    """Get the names of the gates appearing in the given circuit.

    :return: the keys view of the gate-count dictionary computed for ``circuit``.
    """
    gate_count = compute_gate_count(circuit)
    return gate_count.keys()
